from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import logging
import time
import os
from collections import defaultdict

import numpy as np
import torch
import h5py

from lib.core.evaluate import accuracy
from lib.core.inference import get_final_preds
from lib.utils.transforms import flip_back
from lib.utils.vis import save_debug_images


logger = logging.getLogger(__name__)


def percent_of_datasource(meta):
    """Describe how the first view's batch is split across data sources.

    Args:
        meta: indexable collection whose first element is a mapping with a
            ``'source'`` entry holding one source label per sample.

    Returns:
        A string made of ``'<source> <share>\\t'`` pieces, one per distinct
        source in order of first appearance, where ``<share>`` is the
        fraction of the batch formatted as a percentage with one decimal.
        Empty when the batch has no samples.
    """
    sources = meta[0]['source']
    total = len(sources)

    # Tally occurrences; a plain dict keeps first-seen order of the labels.
    tally = {}
    for label in sources:
        tally[label] = tally.get(label, 0) + 1

    return ''.join(
        '{} {:.1%}\t'.format(label, hits / total)
        for label, hits in tally.items()
    )


def train(config, train_loader, model, criterion, optimizer, epoch,
          output_dir, writer_dict, rank, device_id):
    """Run one pass over ``train_loader``, updating ``model`` in place.

    Every process runs the forward pass, backward pass and optimizer step.
    Only the process with ``rank == 0`` tracks loss/accuracy, logs progress,
    writes scalars through ``writer_dict['writer']`` and saves debug images.

    Args:
        config: experiment configuration; ``config.PRINT_FREQ`` is read here
            and the whole object is forwarded to ``save_debug_images``.
        train_loader: iterable yielding
            ``(input_data, target, target_weight, meta)`` batches.
        model: network to train; it is switched to train mode.
        criterion: callable invoked as
            ``criterion(output, target, target_weight)``; its result must
            support ``.backward()`` and ``.item()``.
        optimizer: optimizer used for ``zero_grad()`` / ``step()``; the
            learning rate of its first parameter group is logged.
        epoch: epoch number, used only in the log message.
        output_dir: directory used to build the debug-image file prefix.
        writer_dict: mapping with ``'writer'`` (object exposing
            ``add_scalar``) and ``'train_global_steps'`` (int counter that
            is incremented each time scalars are written). Only touched on
            rank 0.
        rank: process rank; ``0`` enables the bookkeeping described above.
        device_id: index of the CUDA device the batch is moved to.
    """
    device = torch.device("cuda", device_id)

    batch_time = AverageMeter()
    data_time = AverageMeter()

    is_main_process = rank == 0
    if is_main_process:
        losses = AverageMeter()
        acc = AverageMeter()

    # switch to train mode
    model.train()

    end = time.time()
    for i, (input_data, target, target_weight, meta) in enumerate(train_loader):
        # Batch layout as noted by the original author (not verified here;
        # the first three items are moved with ``.to(device)``, so they are
        # handled as tensors):
        #   input:  [N, 3, H, W]
        #   target: [N, 16, h, w]
        #   weight: [N, 16, 1]
        #   meta:   per-batch metadata, forwarded to save_debug_images

        # measure data loading time
        data_time.update(time.time() - end)
        input_data = input_data.to(device, non_blocking=False)
        target = target.to(device, non_blocking=False)
        target_weight = target_weight.to(device, non_blocking=False)

        # compute output
        output = model(input_data)
        loss = criterion(output, target, target_weight)

        # compute gradient and do update step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if is_main_process:
            # measure accuracy and record loss
            losses.update(loss.item(), input_data.size(0))

            _, avg_acc, cnt, pred = accuracy(output.detach().cpu().numpy(),
                                             target.detach().cpu().numpy())
            acc.update(avg_acc, cnt)

        # measure elapsed time
        # The timer is reset on every rank: if ``end`` only advanced on
        # rank 0, ``data_time`` on the other ranks would record the time
        # since the start of the epoch instead of the per-batch load time.
        batch_time.update(time.time() - end)
        end = time.time()

        if is_main_process and i % config.PRINT_FREQ == 0:
            msg = 'Epoch: [{0}][{1}/{2}]\t' \
                  'Time {batch_time.val:.3f}s ({batch_time.avg:.3f}s)\t' \
                  'LR {lr:.5f}\t' \
                  'Speed {speed:.1f} samples/s\t' \
                  'Data {data_time.val:.3f}s ({data_time.avg:.3f}s)\t' \
                  'Loss {loss.val:.5f} ({loss.avg:.5f})\t' \
                  'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(
                      epoch, i, len(train_loader), batch_time=batch_time,
                      lr=optimizer.state_dict()['param_groups'][0]['lr'],
                      speed=input_data.size(0)/batch_time.val,
                      data_time=data_time, loss=losses, acc=acc)
            logger.info(msg)

            writer = writer_dict['writer']
            global_steps = writer_dict['train_global_steps']
            writer.add_scalar('train_loss', losses.val, global_steps)
            writer.add_scalar('train_acc', acc.val, global_steps)
            writer_dict['train_global_steps'] = global_steps + 1

            prefix = '{}_{}'.format(os.path.join(output_dir, 'train'), i)
            # NOTE(review): ``pred * 4`` presumably rescales heatmap
            # coordinates to input-image coordinates (stride 4) -- confirm
            # against the model's heatmap/input size ratio.
            save_debug_images(config, input_data, meta, target, pred * 4, output,
                              prefix)


def validate(config, val_loader, val_dataset, model, criterion, output_dir,
             writer_dict,  # None
             rank, device_id):
    """Evaluate ``model`` on ``val_loader`` and return the dataset's score.

    Runs the model under ``torch.no_grad()``, optionally averages the output
    with a horizontally flipped pass (``config.TEST.FLIP_TEST``), collects
    per-sample predictions and box information, and hands them to
    ``val_dataset.evaluate``.

    Args:
        config: experiment configuration; reads ``MODEL.NUM_JOINTS``,
            ``MODEL.NAME``, ``TEST.FLIP_TEST``, ``TEST.SHIFT_HEATMAP``,
            ``DATASET.TEST_DATASET[0].DATASET`` and ``PRINT_FREQ``.
        val_loader: iterable yielding
            ``(input_data, target, target_weight, meta)`` batches, where
            ``meta`` is a mapping providing at least ``'center'``,
            ``'scale'``, ``'score'`` and ``'image'``.
        val_dataset: dataset providing ``__len__``, ``flip_pairs`` and
            ``evaluate(...)``.
        model: network to evaluate; it is switched to eval mode.
        criterion: callable invoked as
            ``criterion(output, target, target_weight)``.
        output_dir: directory passed to ``val_dataset.evaluate`` and used
            to build the debug-image file prefix.
        writer_dict: not used by this function.
        rank: not used by this function.
        device_id: index of the CUDA device the batch is moved to.

    Returns:
        The second value returned by ``val_dataset.evaluate``.
    """
    # only rank 0 process will enter this function
    device = torch.device("cuda", device_id)

    batch_time = AverageMeter()
    losses = AverageMeter()
    acc = AverageMeter()

    # switch to evaluate mode
    model.eval()

    num_samples = len(val_dataset)
    # Per joint: (x, y, max heatmap value).
    all_preds = np.zeros((num_samples, config.MODEL.NUM_JOINTS, 3),
                         dtype=np.float32)
    # Per sample: center (2), scale (2), prod(scale * 200), score.
    all_boxes = np.zeros((num_samples, 6))

    bbox_ids = np.zeros((num_samples))

    image_path = []
    filenames = []
    imgnums = []
    idx = 0
    with torch.no_grad():
        end = time.time()
        for i, (input_data, target, target_weight, meta) in enumerate(val_loader):
            # compute output
            input_data = input_data.to(device, non_blocking=False)
            output = model(input_data)
            if config.TEST.FLIP_TEST:
                # The batch is mirrored along axis 3 through NumPy (the
                # original author noted that negative-stride indexing was
                # not available on tensors), evaluated, and mapped back
                # with flip_back using the dataset's flip pairs.
                input_flipped = np.flip(input_data.cpu().numpy(), 3).copy()
                input_flipped = torch.from_numpy(input_flipped).to(device, non_blocking=False)
                output_flipped = model(input_flipped)
                output_flipped = flip_back(output_flipped.cpu().numpy(),
                                           val_dataset.flip_pairs)
                output_flipped = torch.from_numpy(output_flipped.copy()).to(device, non_blocking=False)

                # feature is not aligned, shift flipped heatmap for higher accuracy
                if config.TEST.SHIFT_HEATMAP:
                    # clone() first so the source of the shift is not
                    # overwritten while it is being copied.
                    output_flipped[:, :, :, 1:] = \
                        output_flipped.clone()[:, :, :, 0:-1]
                    # output_flipped[:, :, :, 0] = 0

                output = (output + output_flipped) * 0.5

            target = target.to(device, non_blocking=False)
            target_weight = target_weight.to(device, non_blocking=False)

            loss = criterion(output, target, target_weight)

            num_images = input_data.size(0)
            # measure accuracy and record loss
            losses.update(loss.item(), num_images)
            _, avg_acc, cnt, pred = accuracy(output.cpu().numpy(),
                                             target.cpu().numpy())

            acc.update(avg_acc, cnt)

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            c = meta['center'].numpy()
            s = meta['scale'].numpy()
            score = meta['score'].numpy()
            # Box/track identifiers: 'bbox_id' is the mmpose-style key,
            # 'track_id' is the regular key; when both are present the
            # track id wins because it is written last.
            if 'bbox_id' in meta:
                bbox_ids[idx:idx + num_images] = meta['bbox_id'].numpy()
            if 'track_id' in meta:
                bbox_ids[idx:idx + num_images] = meta['track_id'].numpy()

            preds, maxvals, preds_in_input_space = get_final_preds(
                config, output.clone().cpu().numpy(), c, s)

            all_preds[idx:idx + num_images, :, 0:2] = preds[:, :, 0:2]
            all_preds[idx:idx + num_images, :, 2:3] = maxvals
            # double check this all_boxes parts
            all_boxes[idx:idx + num_images, 0:2] = c[:, 0:2]
            all_boxes[idx:idx + num_images, 2:4] = s[:, 0:2]
            all_boxes[idx:idx + num_images, 4] = np.prod(s*200, 1)
            all_boxes[idx:idx + num_images, 5] = score
            image_path.extend(meta['image'])
            if config.DATASET.TEST_DATASET[0].DATASET == 'posetrack':
                filenames.extend(meta['filename'])
                imgnums.extend(meta['imgnum'].numpy())

            idx += num_images

            if i % config.PRINT_FREQ == 0:
                msg = 'Test: [{0}/{1}]\t' \
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
                      'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time,
                          loss=losses, acc=acc)
                logger.info(msg)

                prefix = '{}_{}'.format(os.path.join(output_dir, 'val'), i)
                save_debug_images(config, input_data, meta, target, preds_in_input_space, output,
                                  prefix)

        msg = 'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' \
              'Loss {loss.val:.4f} ({loss.avg:.4f})\t' \
              'Accuracy {acc.val:.3f} ({acc.avg:.3f})'\
            .format(batch_time=batch_time, loss=losses, acc=acc)
        logger.info(msg)

        name_values, perf_indicator = val_dataset.evaluate(
            config, all_preds, output_dir, all_boxes, image_path,
            filenames, imgnums, bbox_ids
        )

        model_name = config.MODEL.NAME
        if isinstance(name_values, list):
            for name_value in name_values:
                _print_name_value(name_value, model_name)
        else:
            _print_name_value(name_values, model_name)

        # NOTE: a disabled HDF5 dump of heatmaps and joint locations used to
        # follow here. The per-sample heatmap buffer and the joint-index
        # lookup that only fed that dump were removed, since nothing
        # consumed them; reintroduce both if the dump is ever re-enabled.

    return perf_indicator


# markdown format output
def _print_name_value(name_value, full_arch_name):
    """Log one mapping of metric names to values as a markdown table.

    Three log records are emitted at INFO level: a header row starting with
    ``Arch`` followed by the metric names, a ``|---`` separator row with one
    cell per column, and a data row with ``full_arch_name`` followed by each
    value formatted to three decimals.

    Args:
        name_value: mapping from metric name to numeric value.
        full_arch_name: string placed in the first column of the data row.
    """
    column_count = len(name_value) + 1  # metrics plus the leading Arch column

    header_cells = ' '.join('| {}'.format(key) for key in name_value.keys())
    logger.info('| Arch ' + header_cells + ' |')

    logger.info('|---' * column_count + '|')

    # Values are formatted only after the header rows have been logged.
    value_cells = ' '.join(
        '| {:.3f}'.format(number) for number in name_value.values()
    )
    logger.info('| ' + full_arch_name + ' ' + value_cells + ' |')


class AverageMeter(object):
    """Track the latest value and a weighted running mean of a quantity.

    Attributes (all reset to 0 by ``reset``):
        val: value passed to the most recent ``update`` call.
        sum: accumulated ``val * n`` over all updates.
        count: accumulated ``n`` over all updates.
        avg: ``sum / count``, or 0 while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear every statistic back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the newest value, weighted by ``n`` samples."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard against dividing by zero when nothing has been counted yet.
        if self.count != 0:
            self.avg = self.sum / self.count
        else:
            self.avg = 0
